{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载基础模型和 LoRA 权重，并使用 Gradio 进行推理。值得注意的是，基础模型和微调后得到的 LoRA 权重不必合并成一个微调模型；这样基础模型就可以灵活地与不同的 LoRA 权重搭配进行推理，以处理不同的任务。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import fire\n",
    "import gradio as gr\n",
    "import torch\n",
    "import transformers\n",
    "from peft import PeftModel\n",
    "from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer\n",
    "\n",
    "from utils.callbacks import Iteratorize, Stream\n",
    "from utils.prompter import Prompter\n",
    "\n",
    "# Pick the inference device: CUDA when available, otherwise CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# MPS overrides the choice above when it is available. Probe for the backend\n",
    "# attribute explicitly (it may be missing on some torch builds) instead of\n",
    "# wrapping the check in a bare `except`, which would also hide unrelated errors.\n",
    "_mps_backend = getattr(torch.backends, \"mps\", None)\n",
    "if _mps_backend is not None and _mps_backend.is_available():\n",
    "    device = \"mps\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model artifacts\n",
    "# Base checkpoint matching the one used by the fine-tuning code (downloaded manually).\n",
    "base_model: str = \"models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16\"\n",
    "# Fine-tuned LoRA adapter weights (can also be downloaded manually).\n",
    "lora_weights: str = \"alpaca-lora-7b\"\n",
    "# Passed as `load_in_8bit` when the base model is loaded on CUDA.\n",
    "load_8bit: bool = False\n",
    "# The prompt template to use, will default to alpaca.\n",
    "prompt_template: str = \"\"\n",
    "\n",
    "# Serving options\n",
    "# '0.0.0.0' allows listening on all network interfaces.\n",
    "server_name: str = \"0.0.0.0\"\n",
    "# Forwarded as `share=` to the Gradio launch call.\n",
    "share_gradio: bool = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message\n"
     ]
    }
   ],
   "source": [
    "# Resolve the base model, falling back to the BASE_MODEL environment variable when unset.\n",
    "if not base_model:\n",
    "    base_model = os.environ.get(\"BASE_MODEL\", \"\")\n",
    "assert base_model, \"Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'\"\n",
    "\n",
    "# Build the prompt helper from the configured template name.\n",
    "prompter = Prompter(prompt_template)\n",
    "# Load the LlamaTokenizer from the base model location.\n",
    "tokenizer = LlamaTokenizer.from_pretrained(base_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.98s/it]\n"
     ]
    }
   ],
   "source": [
    "# Choose per-device keyword arguments once, then load base model + LoRA adapter\n",
    "# with a single pair of calls.\n",
    "if device == \"cuda\":\n",
    "    base_kwargs = {\n",
    "        \"load_in_8bit\": load_8bit,\n",
    "        \"torch_dtype\": torch.float16,\n",
    "        \"device_map\": \"auto\",\n",
    "    }\n",
    "    lora_kwargs = {\"torch_dtype\": torch.float16}\n",
    "elif device == \"mps\":\n",
    "    base_kwargs = {\"device_map\": {\"\": device}, \"torch_dtype\": torch.float16}\n",
    "    lora_kwargs = {\"device_map\": {\"\": device}, \"torch_dtype\": torch.float16}\n",
    "else:\n",
    "    base_kwargs = {\"device_map\": {\"\": device}, \"low_cpu_mem_usage\": True}\n",
    "    lora_kwargs = {\"device_map\": {\"\": device}}\n",
    "\n",
    "# Load the base model.\n",
    "model = LlamaForCausalLM.from_pretrained(base_model, **base_kwargs)\n",
    "# Attach the LoRA weights. Note: check that the peft / transformers versions are\n",
    "# compatible with the adapter's LoraConfig, see\n",
    "# https://huggingface.co/docs/peft/tutorial/peft_model_config?config=LoraConfig\n",
    "model = PeftModel.from_pretrained(model, lora_weights, **lora_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the model\n",
    "# Special-token ids: pad=0 (set on both the model config and the tokenizer), bos=1, eos=2.\n",
    "model.config.pad_token_id = 0\n",
    "tokenizer.pad_token_id = 0\n",
    "model.config.bos_token_id = 1\n",
    "model.config.eos_token_id = 2\n",
    "\n",
    "# Convert to half precision (float16) unless the model was loaded in 8-bit.\n",
    "use_half = not load_8bit\n",
    "if use_half:\n",
    "    model.half()\n",
    "\n",
    "# Switch to evaluation mode.\n",
    "model.eval()\n",
    "\n",
    "# Compile with torch.compile on torch 2+ (skipped on Windows).\n",
    "can_compile = sys.platform != \"win32\" and torch.__version__ >= \"2\"\n",
    "if can_compile:\n",
    "    model = torch.compile(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the inference function used by the Gradio interface\n",
    "def evaluate(\n",
    "    instruction,\n",
    "    input=None,\n",
    "    temperature=0.1,\n",
    "    top_p=0.75,\n",
    "    top_k=40,\n",
    "    num_beams=4,\n",
    "    max_new_tokens=128,\n",
    "    stream_output=False,\n",
    "    **kwargs,\n",
    "):\n",
    "    \"\"\"Generate a response for an instruction, yielding text as it becomes available.\n",
    "\n",
    "    This is a generator. With ``stream_output`` it yields the response extracted\n",
    "    from each streamed partial output; otherwise it yields the final response once.\n",
    "\n",
    "    Args:\n",
    "        instruction: Instruction text passed to ``prompter.generate_prompt``.\n",
    "        input: Optional supplementary input passed to ``prompter.generate_prompt``.\n",
    "        temperature: Sampling temperature; sampling is enabled by default only\n",
    "            when this is greater than 0.\n",
    "        top_p: Nucleus-sampling threshold (applied only when sampling).\n",
    "        top_k: Top-k cutoff (applied only when sampling).\n",
    "        num_beams: Number of beams for beam search.\n",
    "        max_new_tokens: Maximum number of newly generated tokens.\n",
    "        stream_output: Yield partial outputs while generating.\n",
    "        **kwargs: Extra ``GenerationConfig`` fields. An explicit ``do_sample``\n",
    "            overrides the temperature-based default.\n",
    "\n",
    "    Yields:\n",
    "        The response text returned by ``prompter.get_response``.\n",
    "    \"\"\"\n",
    "    # Combine instruction and input into the full prompt, then tokenize to tensors.\n",
    "    prompt = prompter.generate_prompt(instruction, input)\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "    input_ids = inputs[\"input_ids\"].to(device)\n",
    "\n",
    "    # temperature / top_p / top_k are only used in sample-based generation; without\n",
    "    # do_sample=True transformers ignores them and warns. Enable sampling when the\n",
    "    # temperature is positive, and leave the sampling knobs out of the config\n",
    "    # entirely when not sampling so no misleading warning is emitted.\n",
    "    do_sample = kwargs.pop(\"do_sample\", temperature is not None and temperature > 0)\n",
    "    config_kwargs = {\"num_beams\": num_beams, \"do_sample\": do_sample}\n",
    "    if do_sample:\n",
    "        config_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)\n",
    "    generation_config = GenerationConfig(**config_kwargs, **kwargs)\n",
    "\n",
    "    # Arguments shared by the streaming and non-streaming generate calls.\n",
    "    generate_params = {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"generation_config\": generation_config,\n",
    "        \"return_dict_in_generate\": True,\n",
    "        \"output_scores\": True,\n",
    "        \"max_new_tokens\": max_new_tokens,\n",
    "    }\n",
    "\n",
    "    if stream_output:\n",
    "        # Streaming: a Stream stopping criterion is appended so the callback is\n",
    "        # invoked during generation, and Iteratorize presumably exposes those\n",
    "        # callback invocations as an iterator (see utils.callbacks).\n",
    "        def generate_with_callback(callback=None, **gen_kwargs):\n",
    "            gen_kwargs.setdefault(\"stopping_criteria\", transformers.StoppingCriteriaList())\n",
    "            gen_kwargs[\"stopping_criteria\"].append(Stream(callback_func=callback))\n",
    "            with torch.no_grad():\n",
    "                model.generate(**gen_kwargs)\n",
    "\n",
    "        def generate_with_streaming(**gen_kwargs):\n",
    "            return Iteratorize(generate_with_callback, gen_kwargs, callback=None)\n",
    "\n",
    "        with generate_with_streaming(**generate_params) as generator:\n",
    "            for output in generator:\n",
    "                decoded_output = tokenizer.decode(output)\n",
    "                # Stop once the end-of-sequence token has been produced.\n",
    "                if output[-1] in [tokenizer.eos_token_id]:\n",
    "                    break\n",
    "                yield prompter.get_response(decoded_output)\n",
    "        return\n",
    "\n",
    "    # Non-streaming: run generation once with gradients disabled.\n",
    "    with torch.no_grad():\n",
    "        generation_output = model.generate(**generate_params)\n",
    "    s = generation_output.sequences[0]\n",
    "    output = tokenizer.decode(s)\n",
    "    yield prompter.get_response(output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://0.0.0.0:7860\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://localhost:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:540: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:545: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.75` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:562: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `40` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:540: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:545: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.75` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n",
      "/usr/local/anaconda3/envs/ChatTTS/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:562: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `40` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Build the Gradio interface around `evaluate` and launch it.\n",
    "# The input components are listed in the same order as `evaluate`'s positional parameters.\n",
    "gr.Interface(\n",
    "    fn=evaluate,\n",
    "    inputs=[\n",
    "        gr.components.Textbox(\n",
    "            lines=2,\n",
    "            label=\"Instruction\",\n",
    "            placeholder=\"Tell me about alpacas.\",\n",
    "        ),\n",
    "        gr.components.Textbox(lines=2, label=\"Input\", placeholder=\"none\"),\n",
    "        gr.components.Slider(\n",
    "            minimum=0, maximum=1, value=0.1, label=\"Temperature\"\n",
    "        ),\n",
    "        gr.components.Slider(\n",
    "            minimum=0, maximum=1, value=0.75, label=\"Top p\"\n",
    "        ),\n",
    "        gr.components.Slider(\n",
    "            minimum=0, maximum=100, step=1, value=40, label=\"Top k\"\n",
    "        ),\n",
    "        gr.components.Slider(\n",
    "            minimum=1, maximum=4, step=1, value=4, label=\"Beams\"\n",
    "        ),\n",
    "        gr.components.Slider(\n",
    "            minimum=1, maximum=2000, step=1, value=128, label=\"Max tokens\"\n",
    "        ),\n",
    "        gr.components.Checkbox(label=\"Stream output\"),\n",
    "    ],\n",
    "    outputs=[\n",
    "        gr.components.Textbox(\n",
    "            lines=5,\n",
    "            label=\"Output\",\n",
    "        )\n",
    "    ],\n",
    "    title=\"🦙🌲 Alpaca-LoRA\",\n",
    "    description=\"Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).\",\n",
    "# Use the configured `server_name` instead of a hard-coded address so the setting\n",
    "# in the parameter cell actually takes effect.\n",
    ").queue().launch(server_name=server_name, share=share_gradio)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ChatTTS",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
